{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.43923378, -0.89837277,  0.33788246], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#自定义一个Wrapper\n",
    "class Pendulum(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "Pendulum().reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>\n",
      "0 [ 0.96149164  0.27483416 -0.48803374] [0.16144127] -0.10135990478688091\n",
      "20 [0.6045405 0.7965744 3.1099813] [-0.5112315] -1.8168199233925673\n",
      "40 [ 0.8752886 -0.4836009  0.6653375] [1.6132678] -0.3016567703344557\n",
      "60 [-0.6241136   0.78133357 -6.964356  ] [-1.4979022] -9.8915640814747\n",
      "80 [0.72487175 0.6888838  1.829814  ] [-1.1534948] -0.9136735397002199\n",
      "100 [ 0.880359   -0.47430798  2.8805146 ] [0.68453074] -1.0744166950065572\n",
      "120 [-0.11632711  0.993211    5.589544  ] [1.905102] -5.975205760090755\n",
      "140 [0.8625395 0.5059897 3.6681533] [-1.895143] -1.6305873916663607\n",
      "160 [ 0.9960735  -0.08853003  3.3771992 ] [1.7136339] -1.1513421518622369\n",
      "180 [ 0.6618824 -0.7496077  2.3583102] [0.21574916] -1.2744131938051289\n",
      "200 [-0.6627036  0.7488818 -7.344068 ] [-1.1858194] -10.662972460265843\n"
     ]
    }
   ],
   "source": [
    "#测试一个环境\n",
    "def test(env, wrap_action_in_list=False):\n",
    "    print(env)\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    step = 0\n",
    "\n",
    "    while not over:\n",
    "        action = env.action_space.sample()\n",
    "\n",
    "        if wrap_action_in_list:\n",
    "            action = [action]\n",
    "\n",
    "        next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "        if step % 20 == 0:\n",
    "            print(step, state, action, reward)\n",
    "\n",
    "        if step > 200:\n",
    "            break\n",
    "\n",
    "        state = next_state\n",
    "        step += 1\n",
    "\n",
    "\n",
    "test(Pendulum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Eb2U4_K6SNUx"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<StepLimitWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>\n",
      "0 [-0.7358193  -0.67717797 -0.40382123] [-0.5497488] -5.765440561412341\n",
      "20 [-0.9347265   0.35536796  1.9101306 ] [-0.1730721] -8.08375711348729\n",
      "40 [-0.9913878   0.13095891 -2.549925  ] [-0.81193155] -9.712515159385875\n",
      "60 [-0.86229056 -0.5064138   1.1561615 ] [0.8439644] -6.949468564700318\n",
      "80 [-0.6871128   0.72655076 -0.10954288] [-1.3248858] -5.423954906875421\n"
     ]
    }
   ],
   "source": [
    "#修改最大步数\n",
    "class StepLimitWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self, env):\n",
    "        super().__init__(env)\n",
    "        self.current_step = 0\n",
    "\n",
    "    def reset(self):\n",
    "        self.current_step = 0\n",
    "        return self.env.reset()\n",
    "\n",
    "    def step(self, action):\n",
    "        self.current_step += 1\n",
    "        state, reward, done, info = self.env.step(action)\n",
    "\n",
    "        #修改done字段\n",
    "        if self.current_step >= 100:\n",
    "            done = True\n",
    "\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "test(StepLimitWrapper(Pendulum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "F5E6kZfzW8vy"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<NormalizeActionWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>\n",
      "0 [-0.90479934  0.42583814  0.46141297] [0.71307397] -7.3225321184911145\n",
      "20 [-0.9846407  -0.17459278 -2.415618  ] [-0.85981315] -9.384230063667319\n",
      "40 [-0.99319226 -0.11648658  3.4871514 ] [-0.6373727] -10.367310571865762\n",
      "60 [-0.86642474  0.49930772 -3.734519  ] [0.1263556] -8.252804318413528\n",
      "80 [-0.6702318 -0.7421518  1.911543 ] [0.43951586] -5.680660931112576\n",
      "100 [-0.5961345  0.8028846 -1.0250479] [0.12384028] -4.9869101896966495\n",
      "120 [-0.64067    -0.76781636 -1.2842121 ] [-0.6126538] -5.301933755735472\n",
      "140 [-0.7986208  0.6018345  2.6601052] [-0.72427845] -6.938714204658353\n",
      "160 [-0.9983815   0.05687125 -3.7444353 ] [-0.37042484] -10.917945192112665\n",
      "180 [-0.9121499 -0.4098567  3.2691987] [-0.35181427] -8.463830116640269\n",
      "200 [-0.71257746  0.70159346 -1.1502234 ] [0.67964196] -5.722462683064076\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "#修改动作空间\n",
    "class NormalizeActionWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self, env):\n",
    "        #获取动作空间\n",
    "        action_space = env.action_space\n",
    "\n",
    "        #动作空间必须是连续值\n",
    "        assert isinstance(action_space, gym.spaces.Box)\n",
    "\n",
    "        #重新定义动作空间,在正负一之间的连续值\n",
    "        #这里其实只影响env.action_space.sample的返回结果\n",
    "        #实际在计算时,还是正负2之间计算的\n",
    "        env.action_space = gym.spaces.Box(low=-1,\n",
    "                                          high=1,\n",
    "                                          shape=action_space.shape,\n",
    "                                          dtype=np.float32)\n",
    "\n",
    "        super().__init__(env)\n",
    "\n",
    "    def reset(self):\n",
    "        return self.env.reset()\n",
    "\n",
    "    def step(self, action):\n",
    "        #重新缩放动作的值域\n",
    "        action = action * 2.0\n",
    "\n",
    "        if action > 2.0:\n",
    "            action = 2.0\n",
    "\n",
    "        if action < -2.0:\n",
    "            action = -2.0\n",
    "\n",
    "        return self.env.step(action)\n",
    "\n",
    "\n",
    "test(NormalizeActionWrapper(Pendulum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bBlS9YxYSpJn"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<StateStepWrapper<Pendulum<TimeLimit<OrderEnforcing<PassiveEnvChecker<PendulumEnv<Pendulum-v1>>>>>>>\n",
      "0 [ 0.59643936 -0.80265814  0.02045725  0.        ] [0.0625756] -0.8681826781206066\n",
      "20 [ 0.69300234  0.72093534 -2.91545701  0.2       ] [0.48333085] -1.4984907639906544\n",
      "40 [-0.40385997  0.91482079  5.91973925  0.4       ] [-0.16275077] -7.450653647426798\n",
      "60 [0.8851468  0.46531188 2.53248    0.6       ] [1.4304487] -0.8776350784009858\n",
      "80 [ 0.82157564 -0.57009953  1.69356799  0.8       ] [0.3554661] -0.6549398995047444\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/gym/spaces/box.py:127: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n"
     ]
    }
   ],
   "source": [
    "from gym.wrappers import TimeLimit\n",
    "\n",
    "\n",
    "#修改状态\n",
    "class StateStepWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self, env):\n",
    "\n",
    "        #状态空间必须是连续值\n",
    "        assert isinstance(env.observation_space, gym.spaces.Box)\n",
    "\n",
    "        #增加一个新状态字段\n",
    "        low = np.concatenate([env.observation_space.low, [0.0]])\n",
    "        high = np.concatenate([env.observation_space.high, [1.0]])\n",
    "\n",
    "        env.observation_space = gym.spaces.Box(low=low,\n",
    "                                               high=high,\n",
    "                                               dtype=np.float32)\n",
    "\n",
    "        super().__init__(env)\n",
    "\n",
    "        self.step_current = 0\n",
    "\n",
    "    def reset(self):\n",
    "        self.step_current = 0\n",
    "        return np.concatenate([self.env.reset(), [0.0]])\n",
    "\n",
    "    def step(self, action):\n",
    "        self.step_current += 1\n",
    "        state, reward, done, info = self.env.step(action)\n",
    "\n",
    "        #根据step_max修改done\n",
    "        if self.step_current >= 100:\n",
    "            done = True\n",
    "\n",
    "        return self.get_state(state), reward, done, info\n",
    "\n",
    "    def get_state(self, state):\n",
    "        #添加一个新的state字段\n",
    "        state_step = self.step_current / 100\n",
    "\n",
    "        return np.concatenate([state, [state_step]])\n",
    "\n",
    "\n",
    "test(StateStepWrapper(Pendulum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8cxnE5bdaQ_3",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cpu device\n",
      "------------------------------------\n",
      "| time/                 |          |\n",
      "|    fps                | 919      |\n",
      "|    iterations         | 100      |\n",
      "|    time_elapsed       | 0        |\n",
      "|    total_timesteps    | 500      |\n",
      "| train/                |          |\n",
      "|    entropy_loss       | -1.44    |\n",
      "|    explained_variance | -0.00454 |\n",
      "|    learning_rate      | 0.0007   |\n",
      "|    n_updates          | 99       |\n",
      "|    policy_loss        | -43.2    |\n",
      "|    std                | 1.02     |\n",
      "|    value_loss         | 976      |\n",
      "------------------------------------\n",
      "------------------------------------\n",
      "| time/                 |          |\n",
      "|    fps                | 897      |\n",
      "|    iterations         | 200      |\n",
      "|    time_elapsed       | 1        |\n",
      "|    total_timesteps    | 1000     |\n",
      "| train/                |          |\n",
      "|    entropy_loss       | -1.43    |\n",
      "|    explained_variance | 4.43e-05 |\n",
      "|    learning_rate      | 0.0007   |\n",
      "|    n_updates          | 199      |\n",
      "|    policy_loss        | -28.2    |\n",
      "|    std                | 1.01     |\n",
      "|    value_loss         | 936      |\n",
      "------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.a2c.a2c.A2C at 0x7f94cc7fee80>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3 import A2C\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "\n",
    "#使用Monitor Wrapper,会在训练的过程中输出rollout/ep_len_mean和rollout/ep_rew_mean,就是增加些日志\n",
    "#gym升级到0.26以后失效了,可能是因为使用了自定义的wapper\n",
    "env = DummyVecEnv([lambda: Monitor(Pendulum())])\n",
    "\n",
    "A2C('MlpPolicy', env, verbose=1).learn(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zuIcbfv3g9dd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<stable_baselines3.common.vec_env.vec_normalize.VecNormalize object at 0x7f949b273e20>\n",
      "0 [[-0.00487219  0.00638567  0.00554428]] [array([1.8279244], dtype=float32)] [-10.]\n",
      "20 [[-0.03349784 -0.66757905 -2.173565  ]] [array([-1.3404763], dtype=float32)] [-0.16482905]\n",
      "40 [[-1.4011567   0.07794451  1.3022798 ]] [array([1.6657785], dtype=float32)] [-0.16956142]\n",
      "60 [[-1.3572015   0.41527337 -1.5391531 ]] [array([1.3103601], dtype=float32)] [-0.12919044]\n",
      "80 [[-0.34073314 -0.9262497   1.2184559 ]] [array([0.99134326], dtype=float32)] [-0.07326685]\n",
      "100 [[ 1.6391766  1.4822493 -1.3587774]] [array([-1.5195391], dtype=float32)] [-0.04477553]\n",
      "120 [[-0.01510759 -1.150826    1.7960454 ]] [array([-0.46556672], dtype=float32)] [-0.07235025]\n",
      "140 [[-0.5499776  -0.83480734 -2.1066453 ]] [array([-0.03180405], dtype=float32)] [-0.08347733]\n",
      "160 [[ 2.0203962  -0.5923113  -0.62100804]] [array([1.302856], dtype=float32)] [-0.00734869]\n",
      "180 [[ 1.8016039   0.8581704  -0.43550822]] [array([1.2153829], dtype=float32)] [-0.00819429]\n",
      "200 [[-1.0306736   0.62731665  2.0643365 ]] [array([-0.90605944], dtype=float32)] [-0.09977578]\n"
     ]
    }
   ],
   "source": [
    "from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack\n",
    "\n",
    "#VecNormalize,他会对state和reward进行Normalize\n",
    "env = DummyVecEnv([Pendulum])\n",
    "env = VecNormalize(env)\n",
    "\n",
    "test(env, wrap_action_in_list=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "2_gym_wrappers_saving_loading.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
